//! The codegen module provides common functions and data structures used by multiple backends
//! during the code generation process.
#[cfg(unix)]
use crate::fault::FaultInfo;
use crate::{
    backend::RunnableModule,
    backend::{CacheGen, Compiler, CompilerConfig, Features, Token},
    cache::{Artifact, Error as CacheError},
    error::{CompileError, CompileResult, RuntimeError},
    module::{ModuleInfo, ModuleInner},
    structures::Map,
    types::{FuncIndex, FuncSig, SigIndex},
};
use smallvec::SmallVec;
use std::collections::HashMap;
use std::fmt;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::{Arc, RwLock};
use wasmparser::{self, WasmDecoder};
use wasmparser::{Operator, Type as WpType};

/// A boxed callback that is called when a breakpoint occurs.
///
/// The callback receives a [`BreakpointInfo`] by value and returns either `Ok(())` or a
/// [`RuntimeError`]. The `Send + Sync + 'static` bounds are part of the alias, so every
/// handler stored under this type has to satisfy them.
pub type BreakpointHandler =
    Box<dyn Fn(BreakpointInfo) -> Result<(), RuntimeError> + Send + Sync + 'static>;

/// Maps instruction pointers to their breakpoint handlers.
///
/// The key is the instruction pointer as a `usize`. The map lives behind an [`Arc`], so
/// cloning a `BreakpointMap` shares the same underlying table instead of copying it.
pub type BreakpointMap = Arc<HashMap<usize, BreakpointHandler>>;

/// An event generated during parsing of a wasm binary.
///
/// `'a` is the lifetime parameter of the wrapped [`Operator`]; `'b` is the lifetime of the
/// borrow used by the [`Event::Wasm`] variant.
#[derive(Debug)]
pub enum Event<'a, 'b> {
    /// An internal event created by the parser used to provide hooks during code generation.
    Internal(InternalEvent),
    /// An event generated by parsing a wasm operator; the operator is borrowed.
    Wasm(&'b Operator<'a>),
    /// An event generated by parsing a wasm operator that contains an owned `Operator`.
    WasmOwned(Operator<'a>),
}

/// Kinds of `InternalEvent`s created during parsing.
///
/// `Debug` is implemented by hand for this type: the `Breakpoint` variant holds a
/// `Box<dyn Fn ...>`, which has no `Debug` implementation, so it cannot be derived.
pub enum InternalEvent {
    /// A function parse is about to begin.
    // NOTE(review): the meaning of the `u32` payload is not visible in this file —
    // presumably a function index; confirm against the parser.
    FunctionBegin(u32),
    /// A function parsing has just completed.
    FunctionEnd,
    /// A breakpoint emitted during parsing, carrying the handler to call.
    Breakpoint(BreakpointHandler),
    /// Indicates setting an internal field.
    // NOTE(review): the `u32` presumably identifies the internal field — confirm.
    SetInternal(u32),
    /// Indicates getting an internal field.
    // NOTE(review): the `u32` presumably identifies the internal field — confirm.
    GetInternal(u32),
}

/// Prints only the variant name; payloads (including the breakpoint handler, which is not
/// `Debug`) are deliberately left out of the output.
impl fmt::Debug for InternalEvent {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let variant_name = match *self {
            InternalEvent::FunctionBegin(..) => "FunctionBegin",
            InternalEvent::FunctionEnd => "FunctionEnd",
            InternalEvent::Breakpoint(..) => "Breakpoint",
            InternalEvent::SetInternal(..) => "SetInternal",
            InternalEvent::GetInternal(..) => "GetInternal",
        };
        f.write_str(variant_name)
    }
}

/// A region of Wasm code, expressed as a pair of byte offsets measured from the
/// beginning of the code section.
///
/// Invariant: `start` must be less than or equal to `end`.
#[derive(Clone, Copy, Debug)]
pub struct WasmSpan {
    /// Start offset in bytes from the beginning of the Wasm code section
    start: u32,
    /// End offset in bytes from the beginning of the Wasm code section
    end: u32,
}

impl WasmSpan {
    /// Builds a `WasmSpan` from its two offsets.
    ///
    /// The caller must pass `start <= end`. The check is a `debug_assert!`, so it only
    /// fires in builds that have debug assertions enabled.
    // TODO: mark this function as `const` when asserts get stabilized as `const`
    // see: https://github.com/rust-lang/rust/issues/57563
    pub fn new(start: u32, end: u32) -> Self {
        debug_assert!(start <= end);
        WasmSpan { start, end }
    }

    /// Returns the start offset, in bytes from the beginning of the Wasm code section.
    pub const fn start(&self) -> u32 {
        self.start
    }

    /// Returns the end offset, in bytes from the beginning of the Wasm code section.
    pub const fn end(&self) -> u32 {
        self.end
    }

    /// Returns the length of the span in bytes, i.e. `end - start`.
    pub const fn size(&self) -> u32 {
        self.end() - self.start()
    }
}

/// Information for a breakpoint (Unix build).
///
/// This is the value handed to a [`BreakpointHandler`].
#[cfg(unix)]
pub struct BreakpointInfo<'a> {
    /// Borrowed fault information, or `None` when no fault is attached.
    pub fault: Option<&'a FaultInfo>,
}

/// Information for a breakpoint (non-Unix build).
///
/// `FaultInfo` is only imported on Unix, so this variant keeps the same field name with a
/// unit placeholder type.
#[cfg(not(unix))]
pub struct BreakpointInfo {
    /// Fault placeholder.
    pub fault: Option<()>,
}

/// A trait that represents the functions needed to be implemented to generate code for a module.
///
/// Type parameters:
/// - `FCG`: the function-scope code generator handed out by `next_function`.
/// - `RM`: the [`RunnableModule`] type returned by `finalize`.
/// - `E`: the backend's error type. It must be `Debug`; `StreamingCompiler::compile`
///   formats it with `{:?}` when `finalize` fails.
pub trait ModuleCodeGenerator<FCG: FunctionCodeGenerator<E>, RM: RunnableModule, E: Debug> {
    /// Creates a new module code generator.
    fn new() -> Self;

    /// Creates a new module code generator for specified target.
    ///
    /// Each argument is optional. `StreamingCompiler::compile` passes the `triple`,
    /// `cpu_name` and `cpu_features` fields of the `CompilerConfig` here, and only does so
    /// when `backend_id()` is `"llvm"`.
    fn new_with_target(
        triple: Option<String>,
        cpu_name: Option<String>,
        cpu_features: Option<String>,
    ) -> Self;

    /// Returns the backend id associated with this MCG.
    fn backend_id() -> &'static str;

    /// Returns whether this compiler requires the module to be validated before
    /// compilation. Defaults to `true`.
    fn requires_pre_validation() -> bool {
        true
    }

    /// Feeds the compiler config. The default implementation ignores it and returns `Ok(())`.
    fn feed_compiler_config(&mut self, _config: &CompilerConfig) -> Result<(), E> {
        Ok(())
    }
    /// Adds an import function with the given signature index.
    fn feed_import_function(&mut self, _sigindex: SigIndex) -> Result<(), E>;
    /// Sets the signatures (signature index to signature).
    fn feed_signatures(&mut self, signatures: Map<SigIndex, FuncSig>) -> Result<(), E>;
    /// Sets function signatures (function index to signature index).
    fn feed_function_signatures(&mut self, assoc: Map<FuncIndex, SigIndex>) -> Result<(), E>;
    /// Checks the precondition for a module.
    fn check_precondition(&mut self, module_info: &ModuleInfo) -> Result<(), E>;
    /// Creates a new function and returns the function-scope code generator for it.
    ///
    /// `loc` is the span of the function within the Wasm code section.
    fn next_function(
        &mut self,
        module_info: Arc<RwLock<ModuleInfo>>,
        loc: WasmSpan,
    ) -> Result<&mut FCG, E>;
    /// Finalizes this module, consuming the generator.
    ///
    /// On success returns the runnable module, optional [`DebugMetadata`], and the cache
    /// generator for the compiled module.
    fn finalize(
        self,
        module_info: &ModuleInfo,
    ) -> Result<(RM, Option<DebugMetadata>, Box<dyn CacheGen>), E>;

    /// Creates a module from cache.
    ///
    /// # Safety
    ///
    /// NOTE(review): the safety contract is not spelled out in this file. Presumably the
    /// artifact must have been produced by this same backend and be untampered — confirm
    /// with the backend implementations.
    unsafe fn from_cache(cache: Artifact, _: Token) -> Result<ModuleInner, CacheError>;
}

/// Mock item when compiling without debug info generation.
///
/// Stands in for `wasm_debug::types::CompiledFunctionData`, which is only imported when
/// the `generate-debug-information` feature is enabled.
#[cfg(not(feature = "generate-debug-information"))]
type CompiledFunctionData = ();

/// Mock item when compiling without debug info generation.
///
/// Stands in for `wasm_debug::types::ValueLabelsRangesInner`, which is only imported when
/// the `generate-debug-information` feature is enabled.
#[cfg(not(feature = "generate-debug-information"))]
type ValueLabelsRangesInner = ();

#[cfg(feature = "generate-debug-information")]
use wasm_debug::types::{CompiledFunctionData, ValueLabelsRangesInner};

#[derive(Clone, Debug)]
/// Useful information for debugging gathered by compiling a Wasm module.
///
/// Without the `generate-debug-information` feature, `CompiledFunctionData` and
/// `ValueLabelsRangesInner` are aliases for `()`.
pub struct DebugMetadata {
    /// [`CompiledFunctionData`] in [`FuncIndex`] order
    pub func_info: Map<FuncIndex, CompiledFunctionData>,
    /// [`ValueLabelsRangesInner`] in [`FuncIndex`] order
    pub inst_info: Map<FuncIndex, ValueLabelsRangesInner>,
    /// Stack slot offsets in [`FuncIndex`] order
    pub stack_slot_offsets: Map<FuncIndex, Vec<Option<i32>>>,
    /// function pointers and their lengths
    // NOTE(review): these are raw pointers; nothing in this file establishes how long
    // they stay valid — confirm against the backend that fills this in.
    pub pointers: Vec<(*const u8, usize)>,
}

/// A streaming compiler which is designed to generate code for a module based on a stream
/// of wasm parser events.
///
/// Only the middleware-chain generator is stored; the remaining type parameters are carried
/// through `PhantomData` markers because the module code generator is created inside
/// `compile` rather than held as a field.
pub struct StreamingCompiler<
    MCG: ModuleCodeGenerator<FCG, RM, E>,
    FCG: FunctionCodeGenerator<E>,
    RM: RunnableModule + 'static,
    E: Debug,
    CGEN: Fn() -> MiddlewareChain,
> {
    // Called once per `compile` to build a fresh `MiddlewareChain`.
    middleware_chain_generator: CGEN,
    _phantom_mcg: PhantomData<MCG>,
    _phantom_fcg: PhantomData<FCG>,
    _phantom_rm: PhantomData<RM>,
    _phantom_e: PhantomData<E>,
}

/// A simple generator for a `StreamingCompiler`.
///
/// This type holds no data (only `PhantomData` markers). Its `new` function returns a
/// `StreamingCompiler` whose chain generator produces an empty `MiddlewareChain`.
pub struct SimpleStreamingCompilerGen<
    MCG: ModuleCodeGenerator<FCG, RM, E>,
    FCG: FunctionCodeGenerator<E>,
    RM: RunnableModule + 'static,
    E: Debug,
> {
    _phantom_mcg: PhantomData<MCG>,
    _phantom_fcg: PhantomData<FCG>,
    _phantom_rm: PhantomData<RM>,
    _phantom_e: PhantomData<E>,
}

impl<
        MCG: ModuleCodeGenerator<FCG, RM, E>,
        FCG: FunctionCodeGenerator<E>,
        RM: RunnableModule + 'static,
        E: Debug,
    > SimpleStreamingCompilerGen<MCG, FCG, RM, E>
{
    /// Create a new `StreamingCompiler` that runs without any middleware.
    ///
    /// Note that this returns a `StreamingCompiler`, not a `SimpleStreamingCompilerGen`.
    /// Its chain generator yields an empty `MiddlewareChain` each time it is called.
    pub fn new() -> StreamingCompiler<MCG, FCG, RM, E, impl Fn() -> MiddlewareChain> {
        // The constructor itself already has the `Fn() -> MiddlewareChain` shape.
        StreamingCompiler::new(MiddlewareChain::new)
    }
}

impl<
        MCG: ModuleCodeGenerator<FCG, RM, E>,
        FCG: FunctionCodeGenerator<E>,
        RM: RunnableModule + 'static,
        E: Debug,
        CGEN: Fn() -> MiddlewareChain,
    > StreamingCompiler<MCG, FCG, RM, E, CGEN>
{
    /// Create a new `StreamingCompiler` with the given `MiddlewareChain`.
    pub fn new(chain_gen: CGEN) -> Self {
        Self {
            middleware_chain_generator: chain_gen,
            _phantom_mcg: PhantomData,
            _phantom_fcg: PhantomData,
            _phantom_rm: PhantomData,
            _phantom_e: PhantomData,
        }
    }
}

/// Create a new `ValidatingParserConfig` with the given features.
///
/// Only `features.threads` and `features.simd` are taken from the argument. Reference
/// types, bulk memory and multi-value are always disabled. When the
/// `deterministic-execution` feature is enabled, `deterministic_only` is set to `true`.
pub fn validating_parser_config(features: &Features) -> wasmparser::ValidatingParserConfig {
    let operator_config = wasmparser::OperatorValidatorConfig {
        // Driven by the caller's feature set.
        enable_threads: features.threads,
        enable_simd: features.simd,

        // Always off.
        enable_reference_types: false,
        enable_bulk_memory: false,
        enable_multi_value: false,

        #[cfg(feature = "deterministic-execution")]
        deterministic_only: true,
    };

    wasmparser::ValidatingParserConfig { operator_config }
}

fn validate_with_features(bytes: &[u8], features: &Features) -> CompileResult<()> {
    let mut parser =
        wasmparser::ValidatingParser::new(bytes, Some(validating_parser_config(features)));
    loop {
        let state = parser.read();
        match *state {
            wasmparser::ParserState::EndWasm => break Ok(()),
            wasmparser::ParserState::Error(ref err) => Err(CompileError::ValidationError {
                msg: err.message().to_string(),
            })?,
            _ => {}
        }
    }
}

impl<
        MCG: ModuleCodeGenerator<FCG, RM, E>,
        FCG: FunctionCodeGenerator<E>,
        RM: RunnableModule + 'static,
        E: Debug,
        CGEN: Fn() -> MiddlewareChain,
    > Compiler for StreamingCompiler<MCG, FCG, RM, E, CGEN>
{
    /// Compiles `wasm` into a `ModuleInner`.
    ///
    /// Steps, in order:
    /// 1. validate the binary, if the backend asks for pre-validation;
    /// 2. create the module code generator (target-aware for the `"llvm"` backend);
    /// 3. build a middleware chain and stream the module through `parse::read_module`;
    /// 4. finalize the generator; a failure becomes `CompileError::InternalError` holding
    ///    the `Debug` rendering of the backend error;
    /// 5. with the `generate-debug-information` feature, optionally emit and register a
    ///    debug image.
    ///
    /// # Panics
    ///
    /// Panics if the module-info lock is poisoned, or if the shared module info still has
    /// other owners when it is unwrapped at the end. With the
    /// `generate-debug-information` feature it also panics if the debug image cannot be
    /// emitted.
    // `compile_debug_info` is only read when `generate-debug-information` is enabled.
    #[allow(unused_variables)]
    fn compile(
        &self,
        wasm: &[u8],
        compiler_config: CompilerConfig,
        _: Token,
    ) -> CompileResult<ModuleInner> {
        if MCG::requires_pre_validation() {
            validate_with_features(wasm, &compiler_config.features)?;
        }

        // Only the "llvm" backend receives the target description.
        let mut mcg = if MCG::backend_id() == "llvm" {
            MCG::new_with_target(
                compiler_config.triple.clone(),
                compiler_config.cpu_name.clone(),
                compiler_config.cpu_features.clone(),
            )
        } else {
            MCG::new()
        };

        let mut chain = (self.middleware_chain_generator)();
        let info = crate::parse::read_module(wasm, &mut mcg, &mut chain, &compiler_config)?;

        let finalized = mcg.finalize(&info.read().unwrap());
        let (exec_context, compile_debug_info, cache_gen) = match finalized {
            Ok(parts) => parts,
            Err(backend_err) => {
                return Err(CompileError::InternalError {
                    msg: format!("{:?}", backend_err),
                });
            }
        };

        #[cfg(feature = "generate-debug-information")]
        {
            // Debug metadata is dropped unless the config asks for debug info.
            let requested_metadata = if compiler_config.should_generate_debug_info() {
                compile_debug_info
            } else {
                None
            };

            if let Some(dbg_info) = requested_metadata {
                let parsed_debug_info = wasm_debug::read_debuginfo(wasm);
                let vmctx_info = wasm_debug::types::ModuleVmctxInfo::new(
                    crate::vm::Ctx::offset_memory_base() as _,
                    std::mem::size_of::<crate::vm::Ctx>() as _,
                    dbg_info.stack_slot_offsets.values(),
                );
                let address_map =
                    wasm_debug::types::create_module_address_map(dbg_info.func_info.values());
                let value_ranges =
                    wasm_debug::types::build_values_ranges(dbg_info.inst_info.values());

                let debug_image = wasm_debug::emit_debugsections_image(
                    target_lexicon::HOST,
                    std::mem::size_of::<usize>() as u8,
                    &parsed_debug_info,
                    &vmctx_info,
                    &address_map,
                    &value_ranges,
                    &dbg_info.pointers,
                )
                .expect("make debug image");

                let mut module_info = info.write().unwrap();
                module_info
                    .debug_info_manager
                    .register_new_jit_code_entry(&debug_image);
            }
        }

        // By now every other handle to `info` is expected to be gone, so it can be
        // taken out of the `Arc<RwLock<..>>`.
        Ok(ModuleInner {
            cache_gen,
            runnable_module: Arc::new(Box::new(exec_context)),
            info: Arc::try_unwrap(info).unwrap().into_inner().unwrap(),
        })
    }

    /// Rebuilds a module from a cached artifact by forwarding to `MCG::from_cache`.
    ///
    /// # Safety
    ///
    /// Same contract as `ModuleCodeGenerator::from_cache`; the arguments are passed
    /// through unchanged.
    unsafe fn from_cache(
        &self,
        artifact: Artifact,
        token: Token,
    ) -> Result<ModuleInner, CacheError> {
        MCG::from_cache(artifact, token)
    }
}

/// A sink for parse events.
///
/// Middlewares receive a `&mut EventSink` and push the events they want to pass on.
/// `MiddlewareChain::run` then feeds the collected events to the next middleware or, at
/// the end of the chain, to the function code generator.
pub struct EventSink<'a, 'b> {
    // Inline capacity of two events before the `SmallVec` spills to the heap.
    buffer: SmallVec<[Event<'a, 'b>; 2]>,
}

impl<'a, 'b> EventSink<'a, 'b> {
    /// Push a new `Event` to this sink; events are kept in push order.
    pub fn push(&mut self, ev: Event<'a, 'b>) {
        self.buffer.push(ev);
    }
}

/// A container for a chain of middlewares.
///
/// Middlewares run in the order they were pushed. The `Default` value is the same empty
/// chain that [`MiddlewareChain::new`] creates.
#[derive(Default)]
pub struct MiddlewareChain {
    chain: Vec<Box<dyn GenericFunctionMiddleware>>,
}

impl MiddlewareChain {
    /// Create a new empty `MiddlewareChain`.
    pub fn new() -> MiddlewareChain {
        MiddlewareChain { chain: Vec::new() }
    }

    /// Push a new `FunctionMiddleware` to this `MiddlewareChain`.
    ///
    /// The middleware is boxed and appended, so it runs after all previously pushed ones.
    pub fn push<M: FunctionMiddleware + 'static>(&mut self, m: M) {
        self.chain.push(Box::new(m));
    }

    /// Run this chain with the provided function code generator, event and module info.
    ///
    /// `ev` is handed to the first middleware. Whatever a middleware pushes into its sink
    /// becomes the input of the next one. The events left after the last middleware are
    /// fed to `fcg` in order; when `fcg` is `None` they are dropped.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a middleware or by `fcg`. Code generator errors
    /// are converted to `String` through their `Debug` output.
    pub(crate) fn run<E: Debug, FCG: FunctionCodeGenerator<E>>(
        &mut self,
        fcg: Option<&mut FCG>,
        ev: Event,
        module_info: &ModuleInfo,
        source_loc: u32,
    ) -> Result<(), String> {
        let mut sink = EventSink {
            buffer: SmallVec::new(),
        };
        sink.push(ev);

        for middleware in &mut self.chain {
            // Take the pending events in one move and leave an empty buffer behind for
            // this middleware to fill.
            let pending = std::mem::replace(&mut sink.buffer, SmallVec::new());
            for pending_ev in pending {
                middleware.feed_event(pending_ev, module_info, &mut sink, source_loc)?;
            }
        }

        if let Some(fcg) = fcg {
            for out_ev in sink.buffer {
                fcg.feed_event(out_ev, module_info, source_loc)
                    .map_err(|x| format!("{:?}", x))?;
            }
        }

        Ok(())
    }
}

/// A trait that represents the signature required to implement middleware for a function.
///
/// Every type implementing this trait also gets the crate-internal
/// `GenericFunctionMiddleware` through a blanket impl, which is what `MiddlewareChain`
/// stores.
pub trait FunctionMiddleware {
    /// The error type for this middleware's functions.
    ///
    /// It must be `Debug`: the blanket `GenericFunctionMiddleware` impl turns it into a
    /// `String` with `{:?}`.
    type Error: Debug;
    /// Processes the given event, module info and sink.
    ///
    /// Events that should continue down the chain must be pushed into `sink`; `op` is
    /// consumed by this call. `source_loc` is forwarded unchanged by
    /// `MiddlewareChain::run`.
    fn feed_event<'a, 'b: 'a>(
        &mut self,
        op: Event<'a, 'b>,
        module_info: &ModuleInfo,
        sink: &mut EventSink<'a, 'b>,
        source_loc: u32,
    ) -> Result<(), Self::Error>;
}

/// Crate-internal counterpart of `FunctionMiddleware` with the error type fixed to
/// `String`.
///
/// With no associated type, middlewares that have different error types can be stored
/// together as `Box<dyn GenericFunctionMiddleware>`.
pub(crate) trait GenericFunctionMiddleware {
    /// Same contract as `FunctionMiddleware::feed_event`, except that errors are reported
    /// as `String`.
    fn feed_event<'a, 'b: 'a>(
        &mut self,
        op: Event<'a, 'b>,
        module_info: &ModuleInfo,
        sink: &mut EventSink<'a, 'b>,
        source_loc: u32,
    ) -> Result<(), String>;
}

impl<E: Debug, T: FunctionMiddleware<Error = E>> GenericFunctionMiddleware for T {
    fn feed_event<'a, 'b: 'a>(
        &mut self,
        op: Event<'a, 'b>,
        module_info: &ModuleInfo,
        sink: &mut EventSink<'a, 'b>,
        source_loc: u32,
    ) -> Result<(), String> {
        <Self as FunctionMiddleware>::feed_event(self, op, module_info, sink, source_loc)
            .map_err(|x| format!("{:?}", x))
    }
}

/// The function-scope code generator trait.
///
/// `E` is the backend's error type. `MiddlewareChain::run` formats it with `{:?}`, hence
/// the `Debug` bound.
pub trait FunctionCodeGenerator<E: Debug> {
    /// Sets the return type.
    fn feed_return(&mut self, ty: WpType) -> Result<(), E>;

    /// Adds a parameter to the function.
    fn feed_param(&mut self, ty: WpType) -> Result<(), E>;

    /// Adds `n` locals of type `ty` to the function.
    // NOTE(review): `loc` is presumably a source offset like `source_loc` in
    // `feed_event` — confirm against the parser.
    fn feed_local(&mut self, ty: WpType, n: usize, loc: u32) -> Result<(), E>;

    /// Called before the first call to `feed_event`.
    fn begin_body(&mut self, module_info: &ModuleInfo) -> Result<(), E>;

    /// Called for each event that reaches the code generator: wasm operators as well as
    /// internal events, after any middleware has processed them.
    fn feed_event(&mut self, op: Event, module_info: &ModuleInfo, source_loc: u32)
        -> Result<(), E>;

    /// Finalizes the function.
    fn finalize(&mut self) -> Result<(), E>;
}
